{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9210ed0e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-05T14:30:18.532579Z",
     "start_time": "2023-09-05T14:30:17.311863Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa3e4698",
   "metadata": {},
   "source": [
    "## basics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5596316",
   "metadata": {},
   "source": [
    "$$\n",
    "\\begin{split}\n",
    "PE(t,2i)&=\\sin\\left(\\frac{t}{10000^{\\frac{2i}{d_{model}}}}\\right)\\\\\n",
    "PE(t,2i+1)&=\\cos\\left(\\frac{t}{10000^{\\frac{2i}{d_{model}}}}\\right)\\\\\n",
    "\\Downarrow\\\\\n",
    "PE(t,i)&=\\sin\\left(\\frac{t}{10000^{\\frac{i}{d_{model}}}}\\right), \\quad \\text{i is even}\\\\\n",
    "PE(t,i)&=\\cos\\left(\\frac{t}{10000^{\\frac{i-1}{d_{model}}}}\\right), \\quad \\text{i is odd}\\\\\n",
    "\\end{split}\n",
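    "$$\n",
    "\n",
    "Here $t$ is the position in the sequence. In the first two equations $i$ counts sine/cosine pairs ($i = 0, 1, \\dots, d_{model}/2 - 1$); after the arrow, $i$ is the column index itself ($i = 0, 1, \\dots, d_{model} - 1$)."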
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db00dfc8",
   "metadata": {},
   "source": [
    "## cases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d6fcb68b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-05T14:33:48.081571Z",
     "start_time": "2023-09-05T14:33:48.076043Z"
    }
   },
   "outputs": [],
   "source": [
    "max_sequence_length = 10\n",
    "d_model = 6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d83c206e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-05T14:34:14.310316Z",
     "start_time": "2023-09-05T14:34:14.301126Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0., 2., 4.])\n",
      "tensor([1., 3., 5.])\n"
     ]
    }
   ],
   "source": [
    "# Even column indices take the sine, odd column indices the cosine.\n",
    "even_i = torch.arange(start=0, end=d_model, step=2).float()\n",
    "odd_i = torch.arange(start=1, end=d_model, step=2).float()\n",
    "print(even_i, odd_i, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ebc0344a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-05T14:34:29.405863Z",
     "start_time": "2023-09-05T14:34:29.388876Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.float32"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: the index tensor should report a floating dtype (float32).\n",
    "even_i.dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1a462d34",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-05T14:34:35.013766Z",
     "start_time": "2023-09-05T14:34:35.005271Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.],\n",
       "        [1.],\n",
       "        [2.],\n",
       "        [3.],\n",
       "        [4.],\n",
       "        [5.],\n",
       "        [6.],\n",
       "        [7.],\n",
       "        [8.],\n",
       "        [9.]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Column vector of positions t = 0 .. max_sequence_length - 1, shape (max_sequence_length, 1).\n",
    "position = torch.arange(max_sequence_length, dtype=torch.float32).unsqueeze(1)\n",
    "position"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4e2e461d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-05T14:34:59.380826Z",
     "start_time": "2023-09-05T14:34:59.371202Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0000,  0.0000,  0.0000],\n",
       "        [ 0.8415,  0.0464,  0.0022],\n",
       "        [ 0.9093,  0.0927,  0.0043],\n",
       "        [ 0.1411,  0.1388,  0.0065],\n",
       "        [-0.7568,  0.1846,  0.0086],\n",
       "        [-0.9589,  0.2300,  0.0108],\n",
       "        [-0.2794,  0.2749,  0.0129],\n",
       "        [ 0.6570,  0.3192,  0.0151],\n",
       "        [ 0.9894,  0.3629,  0.0172],\n",
       "        [ 0.4121,  0.4057,  0.0194]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Denominator 10000^(i / d_model) for each even column index i.\n",
    "even_denominator = torch.pow(10000, even_i / d_model)\n",
    "even_pe = torch.sin(position / even_denominator)\n",
    "even_pe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "34298983",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-05T14:35:14.475325Z",
     "start_time": "2023-09-05T14:35:14.461627Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.0000,  1.0000,  1.0000],\n",
       "        [ 0.5403,  0.9989,  1.0000],\n",
       "        [-0.4161,  0.9957,  1.0000],\n",
       "        [-0.9900,  0.9903,  1.0000],\n",
       "        [-0.6536,  0.9828,  1.0000],\n",
       "        [ 0.2837,  0.9732,  0.9999],\n",
       "        [ 0.9602,  0.9615,  0.9999],\n",
       "        [ 0.7539,  0.9477,  0.9999],\n",
       "        [-0.1455,  0.9318,  0.9999],\n",
       "        [-0.9111,  0.9140,  0.9998]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Odd column i shares its frequency with the even column i - 1.\n",
    "odd_denominator = torch.pow(10000, (odd_i - 1) / d_model)\n",
    "odd_pe = torch.cos(position / odd_denominator)\n",
    "odd_pe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "07424051",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-05T14:36:05.184665Z",
     "start_time": "2023-09-05T14:36:05.172853Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([10, 3, 2])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pair each sine column with its cosine partner along a new trailing axis.\n",
    "stacked = torch.stack((even_pe, odd_pe), dim=-1)\n",
    "stacked.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "dd89cf76",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-05T14:36:15.237656Z",
     "start_time": "2023-09-05T14:36:15.223865Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],\n",
       "        [ 0.8415,  0.5403,  0.0464,  0.9989,  0.0022,  1.0000],\n",
       "        [ 0.9093, -0.4161,  0.0927,  0.9957,  0.0043,  1.0000],\n",
       "        [ 0.1411, -0.9900,  0.1388,  0.9903,  0.0065,  1.0000],\n",
       "        [-0.7568, -0.6536,  0.1846,  0.9828,  0.0086,  1.0000],\n",
       "        [-0.9589,  0.2837,  0.2300,  0.9732,  0.0108,  0.9999],\n",
       "        [-0.2794,  0.9602,  0.2749,  0.9615,  0.0129,  0.9999],\n",
       "        [ 0.6570,  0.7539,  0.3192,  0.9477,  0.0151,  0.9999],\n",
       "        [ 0.9894, -0.1455,  0.3629,  0.9318,  0.0172,  0.9999],\n",
       "        [ 0.4121, -0.9111,  0.4057,  0.9140,  0.0194,  0.9998]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Merge the two trailing axes so sine and cosine columns interleave: (10, 3, 2) to (10, 6).\n",
    "pe = stacked.flatten(start_dim=1)\n",
    "pe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1d61e13a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-05T14:36:44.735730Z",
     "start_time": "2023-09-05T14:36:44.727093Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8415 0.5403 0.0464 0.9989 0.0022 1.0000 "
     ]
    }
   ],
   "source": [
    "# Recompute row t = 1 of the table with plain NumPy as an independent check.\n",
    "t = 1\n",
    "for i in range(d_model):\n",
    "    trig = np.cos if i % 2 else np.sin\n",
    "    exponent = (i - i % 2) / d_model\n",
    "    print(f'{trig(t / 10000**exponent):.4f}', end=' ')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dafe167e",
   "metadata": {},
   "source": [
    "## from scratch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ab9abe62",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-05T14:37:11.678987Z",
     "start_time": "2023-09-05T14:37:11.662692Z"
    }
   },
   "outputs": [],
   "source": [
    "class SinPositionEncoding(nn.Module):\n",
    "    def __init__(self, max_sequence_length, d_model, base=10000):\n",
    "        super().__init__()\n",
    "        self.max_sequence_length = max_sequence_length\n",
    "        self.d_model = d_model\n",
    "        self.base = base\n",
    "    def forward(self):\n",
    "        even_i = torch.arange(0, self.d_model, 2).float()\n",
    "        odd_i = torch.arange(1, self.d_model, 2).float()\n",
    "        position = torch.arange(self.max_sequence_length, dtype=torch.float).reshape(-1, 1)\n",
    "        even_pe = torch.sin(position / torch.pow(self.base, even_i/self.d_model))\n",
    "        odd_pe = torch.cos(position / torch.pow(self.base, (odd_i-1)/self.d_model))\n",
    "        stacked = torch.stack([even_pe, odd_pe], dim=2)\n",
    "        return torch.flatten(stacked, start_dim=1, end_dim=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "60f8fe62",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-05T14:37:19.202775Z",
     "start_time": "2023-09-05T14:37:19.196722Z"
    }
   },
   "outputs": [],
   "source": [
    "# Reuse the toy sizes defined earlier so this table is comparable with `pe`.\n",
    "spe = SinPositionEncoding(max_sequence_length=max_sequence_length, d_model=d_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2118ee03",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-05T14:37:20.449009Z",
     "start_time": "2023-09-05T14:37:20.437105Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],\n",
       "        [ 0.8415,  0.5403,  0.0464,  0.9989,  0.0022,  1.0000],\n",
       "        [ 0.9093, -0.4161,  0.0927,  0.9957,  0.0043,  1.0000],\n",
       "        [ 0.1411, -0.9900,  0.1388,  0.9903,  0.0065,  1.0000],\n",
       "        [-0.7568, -0.6536,  0.1846,  0.9828,  0.0086,  1.0000],\n",
       "        [-0.9589,  0.2837,  0.2300,  0.9732,  0.0108,  0.9999],\n",
       "        [-0.2794,  0.9602,  0.2749,  0.9615,  0.0129,  0.9999],\n",
       "        [ 0.6570,  0.7539,  0.3192,  0.9477,  0.0151,  0.9999],\n",
       "        [ 0.9894, -0.1455,  0.3629,  0.9318,  0.0172,  0.9999],\n",
       "        [ 0.4121, -0.9111,  0.4057,  0.9140,  0.0194,  0.9998]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pe_from_module = spe()\n",
    "# The module must reproduce the step-by-step table built above.\n",
    "assert torch.allclose(pe_from_module, pe)\n",
    "pe_from_module"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
